Source code for lmcsc.obversation_generator

from typing import Union, List
from copy import deepcopy


[docs] class BaseObversationGenerator:
[docs] def reorder(self, beam_idx: List[int]) -> None: raise NotImplementedError
[docs] def step(self, token_lists: List[List[Union[str, bytes]]]) -> None: raise NotImplementedError
[docs] def show_steps(self) -> None: raise NotImplementedError
[docs] def get_observed_sequences(self) -> List[str]: raise NotImplementedError
[docs] class NextObversationGenerator(BaseObversationGenerator): r""" This class records the progress of the beam search, tracking what has been generated so far and what characters are yet to be generated. Parameters: src (`List[str]`): The source sequences. n_beam (`int`): The number of beams for beam search. n_observed_chars (`int`): The number of characters to observe. is_bytes_level (`bool`): Whether to operate at the byte level. verbose (`bool`, *optional*, defaults to `False`): Whether to enable verbose mode. Attributes: src (`List[Union[str, bytes]]`): The source sequences, potentially encoded to bytes. n_beam (`int`): The number of beams. n_observed_chars (`int`): The number of characters to observe. is_bytes_level (`bool`): Whether operating at byte level. verbose (`bool`): Verbose mode flag. batch_predicts (`List[List[Union[str, bytes]]]`): Predictions for each beam in each batch. batch_steps (`List[List[int]]`): Steps taken for each beam in each batch. batch_verbose_steps (`List[List[List[Union[str, bytes]]]]`): Verbose steps for each beam in each batch. is_finished (`List[List[bool]]`): Flags indicating if each beam in each batch is finished. """ def __init__(self, src, n_beam, n_observed_chars, is_bytes_level, verbose=False): self.src = src self.n_beam = n_beam self.n_observed_chars = n_observed_chars self.is_bytes_level = is_bytes_level self.verbose = verbose if is_bytes_level: self.src = [s.encode("utf-8") for s in src] self.batch_predicts = [[b""] * n_beam for _ in range(len(src))] # TODO: handle bytes level else: self.batch_predicts = [[""] * n_beam for _ in range(len(src))] self.batch_steps = [[0] * n_beam for _ in range(len(src))] self.insert_counters = [[0] * n_beam for _ in range(len(src))] self.verbose = verbose if self.verbose: self.batch_verbose_steps = [ [[] for _ in range(n_beam)] for _ in range(len(src)) ] self.is_finished = [[False] * n_beam for _ in range(len(src))]
[docs] def reorder(self, beam_idx: List[int]) -> None: """ Reorders the beams based on the given indices. Args: beam_idx (List[int]): The indices to reorder the beams. """ self.batch_predicts = [ [self.batch_predicts[i][b] for b in beam] for i, beam in enumerate(beam_idx) ] self.batch_steps = [ [self.batch_steps[i][b] for b in beam] for i, beam in enumerate(beam_idx) ] self.is_finished = [ [self.is_finished[i][b] for b in beam] for i, beam in enumerate(beam_idx) ] self.insert_counters = [ [self.insert_counters[i][b] for b in beam] for i, beam in enumerate(beam_idx) ] if self.verbose: self.batch_verbose_steps = [ [ deepcopy(self.batch_verbose_steps[i][b]) for b in beam ] for i, beam in enumerate(beam_idx) ]
[docs] def step(self, token_lists: List[List[Union[str, bytes]]], step_lists: List[List[int]]): """ Performs a step in the beam search process. Args: token_lists (List[List[Union[str, bytes]]]): The tokens generated in this step. step_lists (List[List[int]]): The corresponding steps for each token. """ for i, (tokens, steps) in enumerate(zip(token_lists, step_lists)): for j, (token, step) in enumerate(zip(tokens, steps)): if self.is_finished[i][j]: continue if token not in {"<|endoftext|>", "</s>", "[SEP]"}: self.batch_predicts[i][j] += token if step == 0: self.insert_counters[i][j] += 1 else: # reset the insert counter self.insert_counters[i][j] = 0 if self.insert_counters[i][j] > 1: # force to move forward step = 1 if self.is_bytes_level: src = self.src[i] while True: try: src[self.batch_steps[i][j] + step:].decode("utf-8") break except: step += 1 self.insert_counters[i][j] = 0 self.batch_steps[i][j] += step if self.verbose: self.batch_verbose_steps[i][j].append(token) else: self.is_finished[i][j] = True
[docs] def show_steps(self) -> None: """ Displays the steps taken in the beam search process. """ batch_predicts = self.batch_verbose_steps if self.verbose else self.batch_predicts for predicts in batch_predicts: for predict in predicts: try: if self.is_bytes_level: if self.verbose: print([s.decode("utf-8") for s in predict]) else: print(predict.decode("utf-8")) else: print(predict) except: print(predict) print()
[docs] def get_observed_sequences(self) -> List[str]: """ Retrieves the observed sequences from the beam search process. Returns: List[str]: The observed sequences for each beam in each batch. """ batch_observed_sequences = [] n_observed_chars = self.n_observed_chars for batch_idx, steps in enumerate(self.batch_steps): observed_sequences = [] src = self.src[batch_idx] for step in steps: # In fact, there is a bug here. # We assume that all Chinese characters are 3 bytes long. # However, some Chinese characters are 4 bytes long. # When a 4 bytes character correct to 3 bytes, it will introduce a garbled character. if self.is_bytes_level: try: token = src[step:].decode("utf-8") observed_sequence = token[:n_observed_chars] except: observed_sequence = src[ step : step + (n_observed_chars * 3) ] else: observed_sequence = src[step : step + n_observed_chars] observed_sequence = observed_sequence.replace(" ", "▁") observed_sequences.append(observed_sequence) batch_observed_sequences.append(observed_sequences) return batch_observed_sequences